{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "# -*- coding:utf-8 -*-\n",
    "# Created by LuoJie at 11/23/19\n",
    "\n",
    "from utils.config import save_wv_model_path, vocab_path\n",
    "import tensorflow as tf\n",
    "from utils.gpu_utils import config_gpu\n",
    "from tensorflow.keras.models import Model, Sequential\n",
    "from tensorflow.keras.layers import GRU, Input, Dense, TimeDistributed, Activation, RepeatVector, Bidirectional\n",
    "from tensorflow.keras.layers import Embedding\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.losses import sparse_categorical_crossentropy\n",
    "import tensorflow as tf\n",
    "from utils.wv_loader import load_embedding_matrix, Vocab\n",
    "\n",
    "from tensorflow.python.ops import nn_ops\n",
    "\n",
    "\n",
    "class Encoder(tf.keras.Model):\n",
    "    \"\"\"Embedding + GRU encoder.\n",
    "\n",
    "    Token ids are looked up in a frozen embedding initialised from `embedding_matrix`\n",
    "    and fed to one GRU that returns both its per-step outputs and its final state.\n",
    "\n",
    "    NOTE(review): `__call__` is overridden directly instead of Keras' `call`;\n",
    "    confirm this is intentional before relying on Keras features built around `call`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, embedding_dim, embedding_matrix, enc_units, batch_sz):\n",
    "        \"\"\"Store the sizes and build the embedding and GRU layers.\n",
    "\n",
    "        Args:\n",
    "            vocab_size: number of rows of the embedding table.\n",
    "            embedding_dim: width of each embedding vector.\n",
    "            embedding_matrix: initial embedding weights (passed as `weights=[...]`).\n",
    "            enc_units: number of GRU units.\n",
    "            batch_sz: batch size, only used by `initialize_hidden_state`.\n",
    "        \"\"\"\n",
    "        super(Encoder, self).__init__()\n",
    "        self.batch_sz = batch_sz\n",
    "        self.enc_units = enc_units\n",
    "        # trainable=False keeps the pre-trained vectors fixed during training\n",
    "        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim, weights=[embedding_matrix],\n",
    "                                                   trainable=False)\n",
    "        # return_sequences + return_state: the call yields (all step outputs, last state)\n",
    "        self.gru = tf.keras.layers.GRU(self.enc_units,\n",
    "                                       return_sequences=True,\n",
    "                                       return_state=True,\n",
    "                                       recurrent_initializer='glorot_uniform')\n",
    "\n",
    "    def __call__(self, x, hidden):\n",
    "        \"\"\"Embed `x` and run the GRU starting from `hidden`.\n",
    "\n",
    "        Returns:\n",
    "            (output, hidden): the GRU's per-step outputs and its final state.\n",
    "        \"\"\"\n",
    "        x = self.embedding(x)\n",
    "        output, hidden = self.gru(x, initial_state=hidden)\n",
    "        return output, hidden\n",
    "\n",
    "    def initialize_hidden_state(self):\n",
    "        \"\"\"Return an all-zero initial state of shape (batch_sz, enc_units).\"\"\"\n",
    "        return tf.zeros((self.batch_sz, self.enc_units))\n",
    "\n",
    "\n",
    "def masked_attention(enc_padding_mask, attn_dist):\n",
    "    \"\"\"Apply the encoder padding mask to softmaxed attention and re-normalize.\n",
    "\n",
    "    Args:\n",
    "        enc_padding_mask: mask broadcastable against the squeezed attention,\n",
    "            presumably (batch_size, enc_len) with 0 on padding positions.\n",
    "        attn_dist: rank-3 attention weights whose last axis has size 1\n",
    "            (it is squeezed on axis 2).\n",
    "\n",
    "    Returns:\n",
    "        Attention weights with the trailing size-1 axis restored. Masked\n",
    "        positions are 0 and each row sums to 1 over axis 1. A row whose mask\n",
    "        removes every position yields all zeros instead of NaN.\n",
    "    \"\"\"\n",
    "    attn_dist = tf.squeeze(attn_dist, axis=2)\n",
    "    mask = tf.cast(enc_padding_mask, dtype=attn_dist.dtype)\n",
    "    attn_dist *= mask  # zero out masked positions\n",
    "    # keepdims gives (batch_size, 1) so the division broadcasts per row\n",
    "    masked_sums = tf.reduce_sum(attn_dist, axis=1, keepdims=True)\n",
    "    # divide_no_nan returns 0 where the denominator is 0 (fully masked row)\n",
    "    attn_dist = tf.math.divide_no_nan(attn_dist, masked_sums)\n",
    "    return tf.expand_dims(attn_dist, axis=2)\n",
    "\n",
    "\n",
    "class BahdanauAttention(tf.keras.layers.Layer):\n",
    "    \"\"\"Additive (Bahdanau-style) attention with an optional coverage term.\n",
    "\n",
    "    NOTE(review): `__call__` is overridden directly instead of Keras' `call` — confirm intended.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, units):\n",
    "        \"\"\"Create the projections: W_s (encoder output), W_h (decoder state), W_c (coverage), V (score).\"\"\"\n",
    "        super(BahdanauAttention, self).__init__()\n",
    "        self.W_s = tf.keras.layers.Dense(units)\n",
    "        self.W_h = tf.keras.layers.Dense(units)\n",
    "        self.W_c = tf.keras.layers.Dense(units)\n",
    "        self.V = tf.keras.layers.Dense(1)\n",
    "\n",
    "    def __call__(self, dec_hidden, enc_output, enc_pad_mask, use_coverage, prev_coverage=None):\n",
    "        \"\"\"Compute the context vector, attention weights and updated coverage.\n",
    "\n",
    "        Args:\n",
    "            dec_hidden: previous decoder GRU state (the attention query).\n",
    "            enc_output: encoder outputs (the attention values).\n",
    "            enc_pad_mask: encoder padding mask. Currently unused: the\n",
    "                `masked_attention` calls below are commented out.\n",
    "            use_coverage: whether to maintain a coverage vector.\n",
    "            prev_coverage: coverage accumulated so far, or None on the first step.\n",
    "\n",
    "        Returns:\n",
    "            (context_vector, attention_weights, coverage). `attention_weights` has its\n",
    "            last axis squeezed away. `coverage` is a tensor when `use_coverage` is\n",
    "            truthy and an empty list otherwise.\n",
    "        \"\"\"\n",
    "        # The query is the previous GRU hidden state and the values are the\n",
    "        # encoder outputs: in seq2seq terms, s_t queries the encoder states h_i.\n",
    "\n",
    "        # hidden shape == (batch_size, hidden size)\n",
    "        # hidden_with_time_axis shape == (batch_size, 1, hidden size)\n",
    "        # we are doing this to perform addition to calculate the score\n",
    "        hidden_with_time_axis = tf.expand_dims(dec_hidden, 1)\n",
    "\n",
    "        if use_coverage and prev_coverage is not None:\n",
    "            # Multiply coverage vector by w_c to get coverage_features.\n",
    "            # self.W_s(enc_output) [batch_sz, max_len, units] self.W_h(hidden_with_time_axis) [batch_sz, 1, units]\n",
    "            # self.W_c(prev_coverage) [batch_sz, max_len, units]  score [batch_sz, max_len, 1]\n",
    "            score = self.V(tf.nn.tanh(self.W_s(enc_output) + self.W_h(hidden_with_time_axis) + self.W_c(prev_coverage)))\n",
    "\n",
    "            # attention_weights shape == (batch_size, max_length, 1); softmax runs over the encoder axis\n",
    "            attention_weights = tf.nn.softmax(score, axis=1)\n",
    "\n",
    "            # attention_weights = masked_attention(enc_pad_mask, attention_weights)\n",
    "            # Coverage accumulates the attention paid so far\n",
    "            coverage = attention_weights + prev_coverage\n",
    "        else:\n",
    "            # score shape == (batch_size, max_length, 1)\n",
    "            # we get 1 at the last axis because we are applying score to self.V\n",
    "            # the shape of the tensor before applying self.V is (batch_size, max_length, units)\n",
    "            # Compute the (unnormalized) attention scores\n",
    "            score = self.V(tf.nn.tanh(\n",
    "                self.W_s(enc_output) + self.W_h(hidden_with_time_axis)))\n",
    "\n",
    "            attention_weights = tf.nn.softmax(score, axis=1)\n",
    "            # attention_weights = masked_attention(enc_pad_mask, attention_weights)\n",
    "            if use_coverage:\n",
    "                # First step with coverage enabled: coverage starts as this step's attention\n",
    "                coverage = attention_weights\n",
    "            else:\n",
    "                # NOTE(review): returns a list here but a tensor in the other branches\n",
    "                coverage = []\n",
    "\n",
    "        # The attention-weighted sum of encoder outputs is returned as the context\n",
    "        # vector, which later becomes part of the decoder input.\n",
    "        # context_vector shape after sum == (batch_size, hidden_size)\n",
    "        context_vector = attention_weights * enc_output\n",
    "        context_vector = tf.reduce_sum(context_vector, axis=1)\n",
    "        return context_vector, tf.squeeze(attention_weights, -1), coverage\n",
    "\n",
    "\n",
    "class Decoder(tf.keras.Model):\n",
    "    \"\"\"GRU decoder that turns (token, context vector) into a softmax over the vocabulary.\n",
    "\n",
    "    NOTE(review): `__call__` is overridden directly instead of Keras' `call` — confirm intended.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, embedding_dim, embedding_matrix, dec_units, batch_sz):\n",
    "        \"\"\"Build the frozen embedding, the GRU and the softmax output layer.\"\"\"\n",
    "        super(Decoder, self).__init__()\n",
    "        self.batch_sz = batch_sz\n",
    "        self.dec_units = dec_units\n",
    "        # trainable=False keeps the pre-trained vectors fixed during training\n",
    "        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim, weights=[embedding_matrix],\n",
    "                                                   trainable=False)\n",
    "        self.gru = tf.keras.layers.GRU(self.dec_units,\n",
    "                                       return_sequences=True,\n",
    "                                       return_state=True,\n",
    "                                       recurrent_initializer='glorot_uniform')\n",
    "        # The softmax activation means `fc` emits probabilities, not logits\n",
    "        self.fc = tf.keras.layers.Dense(vocab_size, activation=tf.keras.activations.softmax)\n",
    "\n",
    "    def __call__(self, x, hidden, enc_output, context_vector):\n",
    "        \"\"\"Run one decoder call.\n",
    "\n",
    "        NOTE(review): `hidden` and `enc_output` are accepted but never read, and the\n",
    "        GRU is called without `initial_state` — confirm this is intended.\n",
    "\n",
    "        Returns:\n",
    "            (dec_x, prediction, state): the concatenated GRU input, the softmax\n",
    "            output of `fc`, and the GRU's final state.\n",
    "        \"\"\"\n",
    "        # enc_output shape == (batch_size, max_length, hidden_size)\n",
    "\n",
    "        # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n",
    "        # print('x:{}'.format(x))\n",
    "        x = self.embedding(x)\n",
    "\n",
    "        # Concatenate the context vector with the embedded previous token to form the GRU input\n",
    "        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n",
    "        dec_x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
    "\n",
    "        # passing the concatenated vector to the GRU\n",
    "        output, state = self.gru(dec_x)\n",
    "\n",
    "        # output shape == (batch_size * 1, hidden_size)\n",
    "        output = tf.reshape(output, (-1, output.shape[2]))\n",
    "\n",
    "        # output shape == (batch_size, vocab)\n",
    "        prediction = self.fc(output)\n",
    "\n",
    "        return dec_x, prediction, state\n",
    "\n",
    "\n",
    "class Pointer(tf.keras.layers.Layer):\n",
    "    \"\"\"Generation-probability gate (p_gen) of the pointer-generator network.\n",
    "\n",
    "    NOTE(review): `__call__` is overridden directly instead of Keras' `call` — confirm intended.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        \"\"\"Create one Dense(1) projection each for the decoder state, decoder input and context.\"\"\"\n",
    "        super(Pointer, self).__init__()\n",
    "        self.w_s_reduce = tf.keras.layers.Dense(1)\n",
    "        self.w_i_reduce = tf.keras.layers.Dense(1)\n",
    "        self.w_c_reduce = tf.keras.layers.Dense(1)\n",
    "\n",
    "    def __call__(self, context_vector, dec_hidden, dec_inp):\n",
    "        \"\"\"Return sigmoid(w_s(dec_hidden) + w_c(context_vector) + w_i(dec_inp)).\"\"\"\n",
    "        return tf.nn.sigmoid(self.w_s_reduce(dec_hidden) + self.w_c_reduce(context_vector) + self.w_i_reduce(dec_inp))\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # Smoke test: run dummy tensors through every layer and print the resulting shapes.\n",
    "    # Configure GPU resources\n",
    "    config_gpu()\n",
    "    # Load the vocabulary and take its size\n",
    "    vocab = Vocab(vocab_path)\n",
    "    vocab_size = vocab.count\n",
    "\n",
    "    # Embedding matrix pre-trained with gensim\n",
    "    embedding_matrix = load_embedding_matrix()\n",
    "\n",
    "    enc_max_len = 200\n",
    "    dec_max_len = 41\n",
    "    batch_size = 64\n",
    "    embedding_dim = 300\n",
    "    units = 1024\n",
    "\n",
    "    encoder = Encoder(vocab_size, embedding_dim, embedding_matrix, units, batch_size)\n",
    "    # Dummy encoder token ids\n",
    "    enc_inp = tf.ones(shape=(batch_size, enc_max_len), dtype=tf.int32)\n",
    "    # Dummy decoder token ids, one column per decoding step\n",
    "    dec_inp = tf.ones(shape=(batch_size, dec_max_len), dtype=tf.int32)\n",
    "    # Encoder padding mask (all ones: nothing is treated as padding)\n",
    "    enc_pad_mask = tf.ones(shape=(batch_size, enc_max_len), dtype=tf.int32)\n",
    "\n",
    "    # Initial encoder state\n",
    "    enc_hidden = encoder.initialize_hidden_state()\n",
    "\n",
    "    enc_output, enc_hidden = encoder(enc_inp, enc_hidden)\n",
    "    print('Encoder output shape: (batch size, sequence length, units) {}'.format(enc_output.shape))\n",
    "    print('Encoder Hidden state shape: (batch size, units) {}'.format(enc_hidden.shape))\n",
    "\n",
    "    attention_layer = BahdanauAttention(10)\n",
    "    # `use_coverage` is a required argument; True makes `coverage` a tensor\n",
    "    # (it would be an empty list otherwise) so its shape can be printed.\n",
    "    context_vector, attention_weights, coverage = attention_layer(enc_hidden,\n",
    "                                                                  enc_output,\n",
    "                                                                  enc_pad_mask,\n",
    "                                                                  use_coverage=True)\n",
    "\n",
    "    print(\"Attention context_vector shape: (batch size, units) {}\".format(context_vector.shape))\n",
    "    # The layer squeezes the trailing axis of the weights before returning them\n",
    "    print(\"Attention weights shape: (batch_size, sequence_length) {}\".format(attention_weights.shape))\n",
    "    print(\"Attention coverage shape: (batch_size, sequence_length, 1) {}\".format(coverage.shape))\n",
    "\n",
    "    decoder = Decoder(vocab_size, embedding_dim, embedding_matrix, units, batch_size)\n",
    "\n",
    "    # Feed the first decoder token as integer ids of shape (batch_size, 1)\n",
    "    dec_x, dec_out, dec_hidden = decoder(tf.expand_dims(dec_inp[:, 0], 1),\n",
    "                                         enc_hidden,\n",
    "                                         enc_output,\n",
    "                                         context_vector)\n",
    "    print('Decoder output shape: (batch_size, vocab size) {}'.format(dec_out.shape))\n",
    "    print('Decoder dec_x shape: (batch_size, 1,embedding_dim + units) {}'.format(dec_x.shape))\n",
    "\n",
    "    pointer = Pointer()\n",
    "    # The pointer takes the float GRU input (dec_x without its time axis), not raw token ids\n",
    "    p_gen = pointer(context_vector, dec_hidden, tf.squeeze(dec_x, axis=1))\n",
    "    print('Pointer p_gen shape: (batch_size,1) {}'.format(p_gen.shape))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook setup: enable autoreload and work from the project root\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys\n",
    "import os\n",
    "\n",
    "# Raw string so the backslashes of the Windows path are not read as escape sequences.\n",
    "# Edit this single constant when running on another machine.\n",
    "PROJECT_ROOT = r'E:\\GitHub\\QA-abstract-and-reasoning'\n",
    "os.chdir(PROJECT_ROOT)\n",
    "# Guarded so that re-running the cell does not add duplicate entries\n",
    "if PROJECT_ROOT not in sys.path:\n",
    "    sys.path.append(PROJECT_ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 Physical GPUs, 1 Logical GPUs\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from pgn.layers import Encoder, Decoder, Pointer, BahdanauAttention\n",
    "from pgn.model import PGN\n",
    "from pgn.batcher import batcher\n",
    "from utils.saveLoader import load_embedding_matrix\n",
    "from utils.saveLoader import Vocab\n",
    "from utils.config import VOCAB_PAD\n",
    "from utils.config_gpu import config_gpu\n",
    "config_gpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模拟model.call里的情况\n",
    "### 构建输入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run utils/params.py\n",
    "# NOTE(review): later cells read a `params` dict that is defined nowhere else in\n",
    "# this notebook, so it presumably comes from the script run above — confirm.\n",
    "# Build the dataset and take one batch as sample input\n",
    "vocab = Vocab(VOCAB_PAD)\n",
    "ds = batcher(vocab, params)\n",
    "enc_data, dec_data = next(iter(ds))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## model用到的层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate each layer the model uses, sized from `params`\n",
    "embedding_matrix = load_embedding_matrix()\n",
    "\n",
    "encoder = Encoder(\n",
    "    vocab_size=params[\"vocab_size\"],\n",
    "    embedding_dim=params[\"embed_size\"],\n",
    "    embedding_matrix=embedding_matrix,\n",
    "    enc_units=params[\"enc_units\"],\n",
    "    batch_size=params[\"batch_size\"],\n",
    ")\n",
    "\n",
    "attention = BahdanauAttention(units=params[\"attn_units\"])\n",
    "\n",
    "decoder = Decoder(\n",
    "    vocab_size=params[\"vocab_size\"],\n",
    "    embedding_dim=params[\"embed_size\"],\n",
    "    embedding_matrix=embedding_matrix,\n",
    "    dec_units=params[\"dec_units\"],\n",
    "    batch_size=params[\"batch_size\"],\n",
    ")\n",
    "\n",
    "pointer = Pointer()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## model.call的参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arguments that model.call would receive, taken from the sampled batch\n",
    "enc_inp = enc_data[\"enc_input\"]\n",
    "dec_inp = dec_data[\"dec_input\"]\n",
    "enc_extended_inp = enc_data[\"extended_enc_input\"]\n",
    "batch_oov_len = enc_data[\"max_oov_len\"]\n",
    "enc_pad_mask = enc_data[\"enc_mask\"]\n",
    "use_coverage = True\n",
    "# No coverage has been accumulated before the first attention call\n",
    "prev_coverage=None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-step results collected by the decoder loop in the next code cell\n",
    "predictions = []\n",
    "attentions = []\n",
    "p_gens = []\n",
    "coverages = []\n",
    "\n",
    "# enc_output (batch_size, enc_len, enc_units)\n",
    "# enc_hidden (batch_size, enc_units)\n",
    "enc_output, enc_hidden = encoder(enc_inp)\n",
    "# The decoder starts from the encoder's final state\n",
    "dec_hidden = enc_hidden\n",
    "\n",
    "# context_vector (batch_size, enc_units)\n",
    "# coverage_ret (batch_size, enc_len, 1)\n",
    "context_vector, _, coverage_ret = attention(dec_hidden,\n",
    "                                             enc_output,\n",
    "                                             enc_pad_mask,\n",
    "                                             use_coverage,\n",
    "                                             prev_coverage)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 进入调用decoder的循环"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One decoder call per decoder-input position (teacher forcing on dec_inp[:, t]).\n",
    "# NOTE: the result lists are only appended to here, so re-running this cell\n",
    "# without re-running the previous one accumulates duplicate steps.\n",
    "for t in range(dec_inp.shape[1]):\n",
    "\n",
    "    # dec_inp[:, t] (batch_size, )\n",
    "    # dec_x (batch_size, 1, embedding_dim + dec_units)\n",
    "    # dec_pred (batch_size, vocab_size)\n",
    "    # dec_hidden (batch_size, dec_units)\n",
    "    dec_x, dec_pred, dec_hidden = decoder(tf.expand_dims(dec_inp[:, t], 1),\n",
    "                                           dec_hidden,\n",
    "                                           enc_output,\n",
    "                                           context_vector)\n",
    "    # Attention for the next step, fed with the coverage returned by the previous call\n",
    "    context_vector, attn, coverage_ret = attention(dec_hidden,\n",
    "                                                        enc_output,\n",
    "                                                        enc_pad_mask,\n",
    "                                                        use_coverage,\n",
    "                                                        coverage_ret)\n",
    "    # p_gen (batch_size, 1)\n",
    "    p_gen = pointer(context_vector, dec_hidden, tf.squeeze(dec_x, axis=1))\n",
    "    coverages.append(coverage_ret)\n",
    "    attentions.append(attn)\n",
    "    predictions.append(dec_pred)\n",
    "    p_gens.append(p_gen)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **cal_final_dist**\n",
    "ps 要在上面的函数循环之后才进入\n",
    "\n",
    "这一步才是今天的重点，先宏观的看下这个函数怎么用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "code_folding": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(32233, 5)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# High-level view: call the library function on the collected per-step results\n",
    "from pgn.model import _calc_final_dist\n",
    "# final_dists: list of (batch_size, vocab_size+batch_oov_len)\n",
    "final_dists = _calc_final_dist(enc_extended_inp,\n",
    "                               predictions,\n",
    "                               attentions,\n",
    "                               p_gens,\n",
    "                               batch_oov_len,\n",
    "                               params[\"vocab_size\"],\n",
    "                               params[\"batch_size\"])\n",
    "# Show the vocabulary size next to this batch's OOV count\n",
    "vocab.count, enc_data[\"max_oov_len\"].numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 进入内部前配置参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bind the names used by the step-by-step cells below (presumably mirroring the\n",
    "# parameter names of `_calc_final_dist` — confirm against pgn/model.py).\n",
    "# The two self-assignments are no-ops kept only to document those inputs.\n",
    "_enc_batch_extend_vocab = enc_extended_inp  # (batch_size, enc_len)\n",
    "vocab_dists = predictions  # list, each (batch_size, vocab_size)\n",
    "attn_dists = attentions  # list, each (batch_size, enc_len)\n",
    "p_gens = p_gens  # list, each (batch_size, 1)\n",
    "batch_oov_len = batch_oov_len  # 5 (for example)\n",
    "vocab_size = params[\"vocab_size\"]  # 32233\n",
    "batch_size = params[\"batch_size\"]  # 64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 原版内部"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "# NOTE: `vocab_dists` and `attn_dists` are rebound to their weighted versions below,\n",
    "# so re-running this cell without re-running the parameter cell applies the weights twice.\n",
    "# Weight the vocabulary distribution by p_gen (shape unchanged)\n",
    "vocab_dists = [p_gen * dist for (p_gen, dist) in zip(p_gens, vocab_dists)]\n",
    "# Weight the attention distribution by (1 - p_gen) (shape unchanged)\n",
    "attn_dists = [(1 - p_gen) * dist for (p_gen, dist) in zip(p_gens, attn_dists)]\n",
    "# extended_vocab_size: vocabulary size after appending this batch's OOV words\n",
    "extended_vocab_size = vocab_size + batch_oov_len \n",
    "# All-zero columns to append to vocab_dists for the OOV slots\n",
    "# extra_zeros (batch_size, batch_oov_len)\n",
    "extra_zeros = tf.zeros((batch_size, batch_oov_len))\n",
    "\n",
    "# Append the zero OOV columns to every step's vocabulary distribution\n",
    "# vocab_dists_extended (batch_size, vocab_size+batch_oov_len)\n",
    "vocab_dists_extended = [tf.concat(axis=1, values=[dist, extra_zeros]) for dist in vocab_dists]\n",
    "\n",
    "# Index array 0 .. batch_size-1\n",
    "batch_nums = tf.range(0, limit=batch_size)\n",
    "# batch_nums (batch_size, 1)\n",
    "batch_nums = tf.expand_dims(batch_nums, 1)\n",
    "# attn_len value: enc_len(200)\n",
    "attn_len = tf.shape(_enc_batch_extend_vocab)[1]\n",
    "# batch_nums (batch_size, enc_len)\n",
    "batch_nums = tf.tile(batch_nums, [1, attn_len])\n",
    "\n",
    "# Pair every encoder token id with the index of the sample it belongs to\n",
    "# indices (batch_size, enc_len, 2)\n",
    "indices = tf.stack((batch_nums, _enc_batch_extend_vocab), axis=2)\n",
    "\n",
    "# Target shape of the scatter: one row of the extended vocabulary per sample\n",
    "shape = [batch_size, extended_vocab_size]\n",
    "\n",
    "# Project each step's attention onto the extended vocabulary\n",
    "attn_dists_projected = [tf.scatter_nd(indices, copy_dist, shape) for copy_dist in attn_dists]\n",
    "\n",
    "# Final distribution per step = generation part + copy part\n",
    "final_dists2 = [vocab_dist + copy_dist for (vocab_dist, copy_dist) in\n",
    "                   zip(vocab_dists_extended, attn_dists_projected)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 改进版\n",
    "感觉不需要这么多列表生成式\n",
    "\n",
    "好吧是需要的\n",
    "\n",
    "但是我感觉可以改成不需要\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# Merge the per-step lists into single tensors with the decoding step on axis 1:\n",
    "#   attentions_  (batch_size, dec_len, enc_len)\n",
    "#   predictions_ (batch_size, dec_len, vocab_size)\n",
    "#   p_gens_      (batch_size, dec_len, 1)\n",
    "attentions_, predictions_, p_gens_ = (\n",
    "    tf.stack(step_values, 1) for step_values in (attentions, predictions, p_gens)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\n",
    "P(w) = p_{gen}P_{vocab}(w)+(1-P_{gen})\\sum_{i:w_i=w}a_i^t\n",
    "$$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# Vectorised version: one scatter over all decoding steps instead of a list per step\n",
    "# Left-hand term of the formula first\n",
    "# _vocab_dists_pgn (batch_size, dec_len, vocab_size)\n",
    "_vocab_dists_pgn = predictions_ * p_gens_\n",
    "# Pad the vocabulary axis with zeros for this batch's OOV words\n",
    "# _extra_zeros (batch_size, dec_len, batch_oov_len)\n",
    "_extra_zeros = tf.zeros((batch_size, p_gens_.shape[1],batch_oov_len))\n",
    "# After the concat the left-hand term is complete\n",
    "# _vocab_dists_extended (batch_size, dec_len, vocab_size+batch_oov_len)\n",
    "_vocab_dists_extended = tf.concat([_vocab_dists_pgn, _extra_zeros], axis=-1)\n",
    "\n",
    "# Right-hand term of the formula\n",
    "# Attention weighted by (1 - p_gen)\n",
    "# _attn_dists_pgn (batch_size, dec_len, enc_len)\n",
    "_attn_dists_pgn = attentions_ * (1-p_gens_)\n",
    "# Extended vocabulary size (not used again in this cell)\n",
    "_extended_vocab_size = vocab_size + batch_oov_len\n",
    "\n",
    "# The values to scatter are _attn_dists_pgn\n",
    "# The scatter output must have the same shape as the left-hand term\n",
    "# shape=[batch_size, dec_len, vocab_size+batch_oov_len]\n",
    "shape = _vocab_dists_extended.shape\n",
    "\n",
    "enc_len = tf.shape(_enc_batch_extend_vocab)[1]\n",
    "dec_len = tf.shape(_vocab_dists_extended)[1]\n",
    "\n",
    "# batch_nums (batch_size, )\n",
    "batch_nums = tf.range(0, limit=batch_size)\n",
    "# batch_nums (batch_size, 1)\n",
    "batch_nums = tf.expand_dims(batch_nums, 1)\n",
    "# batch_nums (batch_size, 1, 1)\n",
    "batch_nums = tf.expand_dims(batch_nums, 2)\n",
    "\n",
    "# Repeat batch_nums dec_len times along axis 1 and enc_len times along axis 2\n",
    "# batch_nums (batch_size, dec_len, enc_len)\n",
    "batch_nums = tf.tile(batch_nums, [1, dec_len, enc_len])\n",
    "\n",
    "# (dec_len, )\n",
    "dec_len_nums = tf.range(0, limit=dec_len)\n",
    "# (1, dec_len)\n",
    "dec_len_nums = tf.expand_dims(dec_len_nums, 0)\n",
    "# (1, dec_len, 1)\n",
    "dec_len_nums = tf.expand_dims(dec_len_nums, 2)\n",
    "# tile repeats a tensor along the given axes\n",
    "# dec_len_nums (batch_size, dec_len, enc_len)\n",
    "dec_len_nums = tf.tile(dec_len_nums, [batch_size, 1, enc_len])\n",
    "\n",
    "# _enc_batch_extend_vocab_expand (batch_size, 1, enc_len)\n",
    "_enc_batch_extend_vocab_expand = tf.expand_dims(_enc_batch_extend_vocab, 1)\n",
    "# _enc_batch_extend_vocab_expand (batch_size, dec_len, enc_len)\n",
    "_enc_batch_extend_vocab_expand = tf.tile(_enc_batch_extend_vocab_expand, [1, dec_len, 1])\n",
    "\n",
    "# The target is a 3-D tensor, so every index has 3 components:\n",
    "# (sample, decoding step, extended-vocabulary id)\n",
    "# indices (batch_size, dec_len, enc_len, 3)\n",
    "indices = tf.stack((batch_nums, \n",
    "                    dec_len_nums, \n",
    "                    _enc_batch_extend_vocab_expand), \n",
    "                   axis=3)\n",
    "\n",
    "# Scatter the weighted attention onto the extended vocabulary\n",
    "attn_dists_projected = tf.scatter_nd(indices, _attn_dists_pgn, shape)\n",
    "# The right-hand term is now complete\n",
    "\n",
    "# Final distribution = generation part + copy part\n",
    "final_dists = _vocab_dists_extended + attn_dists_projected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-step list produced by the original implementation into one\n",
    "# (batch_size, dec_len, extended_vocab_size) tensor so it can be compared\n",
    "# element-wise with `final_dists` from the vectorised version.\n",
    "# The isinstance guard keeps the cell idempotent when it is re-run.\n",
    "if isinstance(final_dists2, list):\n",
    "    final_dists2 = tf.stack(final_dists2, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "80466048"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "64*39*32238"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5103"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "80466048-80460945"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "80460945"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(final_dists2 == final_dists).numpy().sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
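    "两种方法计算得到的 final_dist 结果基本一致：80466048 个元素中有 80460945 个完全相等，其余 5103 个存在差异，推测是浮点数累加顺序不同造成的微小误差（见下文 `tf.scatter_nd` 重复索引部分的说明）。"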
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 调试中"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "难点是公式的这部分 $\\sum_{i:w_i=w}a_i^t$\n",
    "\n",
    "如何从代码层面把相同词的注意力加和到一起\n",
    "\n",
    "我看[有篇文章](https://blog.csdn.net/zlrai5895/article/details/80551056)说：**函数`tf.scatter_nd`更新应用的顺序是非确定性的，所以如果indices包含重复项的话，则输出将是不确定的。**\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 学习tf.scatter_nd\n",
    "```\n",
    "tf.scatter_nd(\n",
    "    indices,\n",
    "    updates,\n",
    "    shape,\n",
    "    name=None\n",
    ")\n",
    "\n",
    "```\n",
    "[这篇文章](https://blog.csdn.net/zlrai5895/article/details/80551056)的例子和图片比较丰富，是翻译了[官方文档](https://tensorflow.google.cn/api_docs/python/tf/scatter_nd?hl=en&version=stable)的中文版解释。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 基本用法\n",
    "`updates` 是要写入的数值张量\n",
    "\n",
    "`shape` 是输出张量的形状\n",
    "\n",
    "`indices` 是 `updates` 中每个数值在输出张量中的位置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=17, shape=(8,), dtype=int32, numpy=array([0, 1, 0, 2, 0, 3, 0, 4])>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Write four values into positions 1, 3, 5 and 7 of a length-8 tensor\n",
    "indices = tf.constant([[1], [3], [5], [7]])\n",
    "updates = tf.constant([1, 2, 3, 4])\n",
    "shape = tf.constant([8])\n",
    "\n",
    "scatter = tf.scatter_nd(indices, updates, shape)\n",
    "scatter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### indices重复项的验证情况\n",
    "可以看到，如果indices包含重复项，那这些重复项的数字会加和然后放到指定位置\n",
    "\n",
    "之所以说可能会数值不确定，是因为浮点数的数值精度问题，按不同顺序加和的数据，得到的最终数值可能不同，但其实大差不差"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=29, shape=(8,), dtype=int32, numpy=array([0, 3, 0, 7, 0, 0, 0, 0])>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Duplicate indices; reuses `updates` and `shape` from the previous code cell.\n",
    "# In the recorded output the values sharing an index are summed (1+2 -> 3, 3+4 -> 7).\n",
    "indices = tf.constant([[1], [1], [3], [3]])\n",
    "scatter = tf.scatter_nd(indices, updates, shape)\n",
    "scatter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 问题迁移step1\n",
    "先从简单的例子开始。\n",
    "\n",
    "`_attn_dists_pgn` 是形状为 `(64, 40, 200)` 的注意力张量。\n",
    "\n",
    "需要转变为`(64, 40, vocab_size+batch_oov_len)`\n",
    "\n",
    "先实现(200,) --> (vocab_size+batch_oov_len,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([64, 40, 200])"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# step1: scatter a single attention row, (enc_len,) -> (vocab_size + batch_oov_len,)\n",
    "# Shape of the full tensor that has to be projected in the end\n",
    "_attn_dists_pgn.shape\n",
    "\n",
    "# Values to scatter: the attention of sample 0 at decoding step 0\n",
    "step1_updates = _attn_dists_pgn[0,0,:]\n",
    "step1_updates.shape\n",
    "\n",
    "# Indices for the scatter; they need one more axis than `updates`\n",
    "step1_indices = tf.expand_dims(_enc_batch_extend_vocab[0], axis=-1)\n",
    "# Shape of the scatter output\n",
    "step1_shape = tf.constant([vocab_size]) + batch_oov_len\n",
    "step1_indices.shape, step1_shape\n",
    "\n",
    "# Run the scatter\n",
    "scatter = tf.scatter_nd(step1_indices, step1_updates, step1_shape)\n",
    "scatter.numpy().sum()\n",
    "\n",
    "# Total attention before the scatter.\n",
    "# It differs slightly from the scattered sum after about the 5th decimal place;\n",
    "# it is unclear whether that difference matters.\n",
    "step1_updates.numpy().sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### step2 (64, 200) -> (64, vocab_size+batch_oov_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([64, 200])"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# step2: values to scatter are the attention of every sample at decoding step 0\n",
    "step2_updates = _attn_dists_pgn[:,0,:]\n",
    "step2_updates.shape\n",
    "\n",
    "# Index array 0 .. batch_size-1\n",
    "# batch_nums (batch_size,)\n",
    "batch_nums = tf.range(0, limit=batch_size)\n",
    "\n",
    "# batch_nums (batch_size, 1)\n",
    "batch_nums = tf.expand_dims(batch_nums, 1)\n",
    "\n",
    "# attn_len : enc_len(200)\n",
    "attn_len = tf.shape(_enc_batch_extend_vocab)[1]\n",
    "\n",
    "# batch_nums (batch_size, enc_len)\n",
    "batch_nums = tf.tile(batch_nums, [1, attn_len])\n",
    "\n",
    "# Pair every encoder token id with the index of the sample it belongs to\n",
    "# indices (batch_size, enc_len, 2)\n",
    "indices = tf.stack((batch_nums, _enc_batch_extend_vocab), axis=2)\n",
    "\n",
    "\n",
    "# Indices for the scatter; they need one more axis than `updates`\n",
    "step2_indices = indices\n",
    "# Shape of the scatter output\n",
    "_extended_vocab_size = vocab_size + batch_oov_len\n",
    "step2_shape =[batch_size, _extended_vocab_size]\n",
    "step2_indices.shape, step2_shape\n",
    "\n",
    "# Run the scatter\n",
    "scatter = tf.scatter_nd(step2_indices, step2_updates, step2_shape)\n",
    "scatter.numpy().sum()\n",
    "\n",
    "# Total attention before the scatter, for comparison with the sum above\n",
    "step2_updates.numpy().sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### step3 (64, 40, 200) -> (64, 40, vocab_size+batch_oov_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorShape([64, 39, 200]), TensorShape([64, 39, 32242]))"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# step3: values to scatter are the full weighted attention tensor\n",
    "step3_updates = _attn_dists_pgn[:]\n",
    "# The scatter output must match the shape of the extended vocabulary distribution\n",
    "step3_shape = _vocab_dists_extended.shape\n",
    "# step3_updates.shape, step3_shape\n",
    "\n",
    "enc_len = tf.shape(_enc_batch_extend_vocab)[1]\n",
    "dec_len = _vocab_dists_extended.shape[1]\n",
    "\n",
    "# batch_nums (batch_size, )\n",
    "batch_nums = tf.range(0, limit=batch_size)\n",
    "# batch_nums (batch_size, 1)\n",
    "batch_nums = tf.expand_dims(batch_nums, 1)\n",
    "# batch_nums (batch_size, 1, 1)\n",
    "batch_nums = tf.expand_dims(batch_nums, 2)\n",
    "# batch_nums (batch_size, dec_len, enc_len)\n",
    "batch_nums = tf.tile(batch_nums, [1, dec_len, enc_len])\n",
    "\n",
    "# (dec_len, )\n",
    "dec_len_nums = tf.range(0, limit=dec_len)\n",
    "# (1, dec_len)\n",
    "dec_len_nums = tf.expand_dims(dec_len_nums, 0)\n",
    "# (1, dec_len, 1)\n",
    "dec_len_nums = tf.expand_dims(dec_len_nums, 2)\n",
    "# tile repeats a tensor along the given axes\n",
    "# dec_len_nums (batch_size, dec_len, enc_len)\n",
    "dec_len_nums = tf.tile(dec_len_nums, [batch_size, 1, enc_len])\n",
    "\n",
    "# (batch_size, 1, enc_len), then repeated to (batch_size, dec_len, enc_len)\n",
    "_enc_batch_extend_vocab_expand = tf.expand_dims(_enc_batch_extend_vocab, 1)\n",
    "_enc_batch_extend_vocab_expand = tf.tile(_enc_batch_extend_vocab_expand, \n",
    "                                         [1, dec_len, 1])\n",
    "\n",
    "# Every index is (sample, decoding step, extended-vocabulary id)\n",
    "indices = tf.stack((batch_nums, \n",
    "                    dec_len_nums, \n",
    "                    _enc_batch_extend_vocab_expand), \n",
    "                   axis=3)\n",
    "\n",
    "# Run the scatter\n",
    "scatter = tf.scatter_nd(indices, step3_updates, step3_shape)\n",
    "# scatter.numpy().sum()\n",
    "\n",
    "# Final distribution = copy part + generation part\n",
    "final_dists3 = scatter + _vocab_dists_extended"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tf2.0]",
   "language": "python",
   "name": "conda-env-tf2.0-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "292.8px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
